inp = torch.randn((1,3,384,384))
The fastai encoder expects a function as its first argument, whereas timm expects a model name string. fastai also defaults to concat pooling, known as catavgmax in timm. With timm's selective pooling, any PoolingType can be used. Experiments show that concat pooling is better on average, so it is set as our default. For any other pool_type fastai falls back to AdaptiveAvgPool2d, while for timm you can choose from the remaining PoolingType options.
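Since concat pooling concatenates the average- and max-pooled features, its output dimension is twice that of plain average pooling. A quick check of that relationship (a sketch; it reuses inp from above and passes pool_type=False to fall back to AdaptiveAvgPool2d):
cat_enc = create_fastai_encoder(xresnet34)                   # default: concat pooling
avg_enc = create_fastai_encoder(xresnet34, pool_type=False)  # plain AdaptiveAvgPool2d
with torch.no_grad(): assert cat_enc(inp).size(-1) == 2 * avg_enc(inp).size(-1)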
fastai_encoder = create_fastai_encoder(xresnet34)
out = fastai_encoder(inp); out.shape
fastai_encoder = create_fastai_encoder(xresnet34, pool_type=False)
out = fastai_encoder(inp); out.shape
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False)
out = model(inp); out.shape
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolingType.Avg)
out = model(inp); out.shape
model = create_encoder("xresnet34", pretrained=False, pool_type=PoolingType.Avg)
out = model(inp); out.shape
model = create_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolingType.Avg)
out = model(inp); out.shape
Vision Transformer is a special case which uses LayerNorm.
vit_model = create_timm_encoder("vit_large_patch16_384", pretrained=False)
out = vit_model(inp); out.shape
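A quick way to confirm this on the encoder just created (a sketch using standard module introspection; it should report LayerNorm present and no BatchNorm):
has_ln = any(isinstance(m, nn.LayerNorm) for m in vit_model.modules())
has_bn = any(isinstance(m, nn.BatchNorm2d) for m in vit_model.modules())
has_ln, has_bn  # expected: (True, False)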
create_mlp_module(1024,4096,128)
create_mlp_module(1024,4096,128,nlayers=3)
create_mlp_module(1024,4096,128,bn=True)
create_mlp_module(1024,4096,128,bn=True,nlayers=3)
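A quick usage sketch (assuming the three positional arguments above are the input, hidden and projection dimensions): the projector should map 1024-d features to 128-d embeddings.
projector = create_mlp_module(1024, 4096, 128, bn=True)
with torch.no_grad(): print(projector(torch.randn(2, 1024)).shape)  # expected: torch.Size([2, 128])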
inp = torch.randn((2,3,384,384))
encoder = create_encoder("xresnet34", pretrained=False)
out = encoder(inp)
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))
encoder = create_encoder("vit_large_patch16_384", pretrained=False)
out = encoder(inp)
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))
create_model can be used to create models for classification, for example to quickly build a model for downstream classification training.
_splitter can be passed to Learner(..., splitter=splitter_func). It can be used to freeze or unfreeze the encoder layers: here the first parameter group is the encoder and the second parameter group is the classification head. Simply by indexing into model[0] and model[1] we can access the encoder and classification head modules.
model = create_model("xresnet34", 10, pretrained=False)
model[1]
with torch.no_grad(): print(model(inp))
model = create_model("vit_large_patch16_384", 10, pretrained=False, use_bn=False, first_bn=False, bn_final=False)
model[1]
with torch.no_grad(): print(model(inp))
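Putting this together with a fastai Learner, a minimal sketch (dls stands in for whatever DataLoaders you have; freeze() then trains only the classification head):
learn = Learner(dls, model, splitter=_splitter)  # dls: placeholder DataLoaders
learn.freeze()  # group 0 (encoder, model[0]) frozen; group 1 (head, model[1]) trains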
Gradient Checkpointing
Gradient checkpointing trades compute for memory, so you can train with a larger image resolution and/or batch size. It's compatible with all timm ResNet, EfficientNet and VisionTransformer models, as well as with fastai models, and it should be easy to implement for any other encoder model you are using.
This is a current fix for using gradient checkpointing with autocast / to_fp16(): https://github.com/pytorch/pytorch/pull/49757/files
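For a custom nn.Sequential encoder, here is a minimal sketch (not the library's implementation) of the same idea using torch.utils.checkpoint, which splits the forward pass into chunks and recomputes activations during the backward pass:
from torch.utils.checkpoint import checkpoint_sequential

class MyCheckpointEncoder(nn.Module):
    "Sketch: recompute activations chunk by chunk to save memory."
    def __init__(self, encoder, checkpoint_nchunks=4):
        super().__init__()
        self.encoder, self.checkpoint_nchunks = encoder, checkpoint_nchunks
    def forward(self, x):
        # Split the sequential encoder into nchunks segments; only segment
        # boundaries are kept in memory, the rest is recomputed on backward.
        return checkpoint_sequential(self.encoder, self.checkpoint_nchunks, x)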
L(timm.list_models("*resnet50*"))[-10:]
encoder = create_encoder("seresnet50", pretrained=False)
encoder = CheckpointResNet(encoder, checkpoint_nchunks=4)
out = encoder(inp)
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))
L(timm.list_models("*efficientnet*"))[-10:]
encoder = create_encoder("tf_efficientnet_b0_ns", pretrained=False)
encoder = CheckpointEfficientNet(encoder, checkpoint_nchunks=4)
out = encoder(inp)
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))
encoder = create_encoder("xresnet34", pretrained=False)
encoder = CheckpointSequential(encoder, checkpoint_nchunks=4)
out = encoder(inp)
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))